#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/4/1 15:28
# @USER    : Shengji He
# @File    : metrics.py
# @Software: PyCharm
# @Version  : Python-
# @TASK: image-quality metrics (SNR, PSNR, RMSE, MAE, SSIM) and helper utilities
import math
import numpy as np
import scipy.ndimage as scim


def factors(n):
    """
    Decompose an integer into its prime factors.

    Uses trial division with exact integer arithmetic (``%`` and ``//``), so
    the result stays correct for integers above 2**53, where float division
    loses precision.

    :param n: integer to factorize
    :return: list of prime factors in non-decreasing order, repeated according
             to multiplicity (e.g. 12 -> [2, 2, 3]); an empty list when n < 2
    """
    result = []
    divisor = 2
    # Only divisors up to sqrt(n) need to be tried: whatever is left above 1
    # once they are exhausted is itself a prime factor.
    while divisor * divisor <= n:
        while n % divisor == 0:
            n //= divisor
            result.append(divisor)
        divisor += 1
    if n > 1:
        result.append(n)
    return result


def compute_gradient_image(image):
    """
    Compute the gradient-magnitude image.

    The partial derivatives along every axis come from ``np.gradient`` (with
    its default spacing of 1 and edge handling). They are combined as
    sqrt(sum_k (d image / d x_k)^2), which for a 2-D image is
    sqrt(gy^2 + gx^2).

    :param image: array-like of any dimensionality, e.g. [height, width];
                  each axis needs at least 2 samples (``np.gradient`` requirement)
    :return: float array with the same shape as ``image``
    """
    image = np.asarray(image)
    grads = np.gradient(image)
    # np.gradient returns a bare array (not a list of arrays) for 1-D input.
    if image.ndim == 1:
        grads = [grads]
    return np.sqrt(sum(g ** 2 for g in grads))


def init_gaussian_kernel(kernel_size, sigma):
    """
    Initialize a normalized, circularly symmetric 2-D Gaussian kernel.

    :param kernel_size: number of rows and columns of the square kernel (a
                        positive integer). The Gaussian is centred at
                        0.5 * (kernel_size - 1), so even sizes are centred
                        between pixels.
    :param sigma: standard deviation of the Gaussian, in pixels (non-zero)
    :return: float64 array of shape (kernel_size, kernel_size) whose entries sum to 1
    :raises ValueError: if kernel_size < 1 or sigma == 0
    """
    if kernel_size < 1:
        raise ValueError('kernel_size must be a positive integer')
    if sigma == 0:
        raise ValueError('sigma must be non-zero')

    center = 0.5 * (kernel_size - 1)
    offsets = np.arange(kernel_size, dtype=np.float64) - center
    # Squared distance of every cell from the kernel centre, via broadcasting
    sq_dist = offsets[:, np.newaxis] ** 2 + offsets[np.newaxis, :] ** 2
    # The usual 1 / (2 * pi * sigma^2) prefactor is omitted on purpose:
    # it cancels out in the normalization below.
    gauss_kernel = np.exp(-sq_dist / (2.0 * sigma ** 2))
    return gauss_kernel / gauss_kernel.sum()


# metric

def calc_snr(oracle, image):
    """
    SNR ---> signal to noise ratio, expressed in dB

    Notation:
        r(x,y)  --->  reference image
        t(x,y)  --->  test image
    SNR = 10 * log_{10}( sum_{x}sum_{y} [ r(x,y) ]^2 /
                         sum_{x}sum_{y} [ r(x,y) - t(x,y) ]^2 )

    Both inputs are converted to float64 first. This stops integer images
    (e.g. uint8) from wrapping around in the subtraction or overflowing when
    squared.

    :param oracle: reference image (array-like)
    :param image: test image, same shape as ``oracle``
    :return: SNR in dB. The ratio uses numpy float division, so identical
             images yield inf (with a RuntimeWarning) instead of raising.
    """
    oracle = np.asarray(oracle, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)
    diff = oracle - image
    num = np.sum(oracle * oracle)
    den = np.sum(diff * diff)
    return 10 * np.log10(num / den)


def calc_psnr(oracle, image):
    """
    PSNR ---> peak signal to noise ratio expressed in dB

    Notation:
        r(x,y)  --->  reference image
        t(x,y)  --->  test image
        nx , ny --->  number of rows and columns
    PSNR = 10 * log_{10}( max( r(x,y) )^2 /
                          ( 1/( nx * ny ) * sum_{x}sum_{y} [ r(x,y) - t(x,y) ]^2 ) )

    The peak is the maximum value of the reference image, not the maximum of
    its dtype. The squared error is averaged over all elements, so arrays of
    any dimensionality are accepted.

    Both inputs are converted to float64 first. This stops integer images
    (e.g. uint8) from wrapping around in the subtraction or overflowing when
    the peak is squared.

    :param oracle: reference image (array-like)
    :param image: test image, same shape as ``oracle``
    :return: PSNR in dB. Identical images yield inf (with a RuntimeWarning).
    """
    oracle = np.asarray(oracle, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)
    diff = oracle - image
    mse = np.mean(diff * diff)
    peak_sqr = np.max(oracle) ** 2
    return 10 * np.log10(peak_sqr / mse)


def calc_rmse(oracle, image):
    """
    RMSE ---> root mean square error

    Notation:
        r(x,y)  --->  reference image
        t(x,y)  --->  test image
        nx , ny --->  number of rows and columns
    RMSE = sqrt( 1/( nx * ny ) * sum_{x}sum_{y} [ r(x,y) - t(x,y) ]^2 )

    The squared error is averaged over all elements, so arrays of any
    dimensionality are accepted. Both inputs are converted to float64 first,
    which stops integer images (e.g. uint8) from wrapping around in the
    subtraction or overflowing when squared.

    :param oracle: reference image (array-like)
    :param image: test image, same shape as ``oracle``
    :return: RMSE, in the same units as the image intensities
    """
    oracle = np.asarray(oracle, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)
    diff = oracle - image
    return np.sqrt(np.mean(diff * diff))


def calc_mae(oracle, image):
    """
    MAE ---> mean absolute error

    Notation:
        r(x,y)  --->  reference image
        t(x,y)  --->  test image
        nx , ny --->  number of rows and columns
    MAE = 1/( nx * ny ) * sum_{x}sum_{y} | r(x,y) - t(x,y) |

    The absolute error is averaged over all elements, so arrays of any
    dimensionality are accepted. Both inputs are converted to float64 first,
    which stops unsigned integer images (e.g. uint8) from wrapping around in
    the subtraction (10 - 12 would otherwise become 254).

    :param oracle: reference image (array-like)
    :param image: test image, same shape as ``oracle``
    :return: MAE, in the same units as the image intensities
    """
    oracle = np.asarray(oracle, dtype=np.float64)
    image = np.asarray(image, dtype=np.float64)
    return np.mean(np.abs(oracle - image))


def compute_map_ssim(image1, image2, window_size, sigma, *, k1=0.001, k2=0.001, verbose=True):
    """
    Calculate map of SSIM values ---- structural similarity index
        Adaptation of the python code of Antoine Vacavant available at the webpage:
        http://isit.u-clermont1.fr/~anvacava/code.html
    Reference:
        "Image Quality Assessment: From Error Visibility to Structural Similarity", Z. Wang, A.C. Bovik et al.,
        IEEE Transactions on Image Processing, Vol. 13, No. 4, April 2004.

    How it works:
    - Local means, variances and covariance are computed by convolving with a
      normalized Gaussian window from ``init_gaussian_kernel``.
    - Image borders use scipy.ndimage's default 'reflect' mode.
    - The stabilizing constants are C1 = (k1 * L)^2 and C2 = (k2 * L)^2, where
      L is the mean of the two images' dynamic ranges (max - min).
    - The defaults k1 = k2 = 0.001 reproduce this function's original
      behaviour; the reference paper uses k1 = 0.01 and k2 = 0.03.
    - If both images are constant, L is 0 and the SSIM ratio can become nan.
    - Inputs are converted to float64, so integer images (e.g. uint8) do not
      overflow when squared.

    :param image1: first 2-D image (array-like)
    :param image2: second 2-D image, same shape as ``image1``
    :param window_size: side length of the square Gaussian window
    :param sigma: standard deviation of the Gaussian window, in pixels
    :param k1: constant used for C1 (keyword-only)
    :param k2: constant used for C2 (keyword-only)
    :param verbose: print progress messages and intermediate parameters (keyword-only)
    :return: (map_ssim, MSSIM) -- the per-pixel SSIM map and its mean value
    """
    # Route progress output through one switchable function.
    log = print if verbose else (lambda *args: None)

    image1 = np.asarray(image1, dtype=np.float64)
    image2 = np.asarray(image2, dtype=np.float64)

    # Initialize normalized circular-symmetric gaussian kernel
    log('\nInitialize gaussian kernel ....')
    gauss_kernel = init_gaussian_kernel(window_size, sigma)

    # Maps of local mean values. scipy.ndimage.convolve is used because
    # numpy.convolve only handles 1-D sequences; by default it treats the
    # boundaries with pixel reflection (see reference literature).
    log('\nCalculating map of mean values ....')
    image_mean1 = scim.convolve(image1, gauss_kernel)
    image_mean2 = scim.convolve(image2, gauss_kernel)
    log('Shape of mean values map: ', image_mean1.shape)

    # Squares and product of the maps of mean values
    image_mean_sqr1 = image_mean1 * image_mean1
    image_mean_sqr2 = image_mean2 * image_mean2
    image_mean12 = image_mean1 * image_mean2

    # Maps of local variances and covariance, via E[XY] - E[X]E[Y]
    log('\nCalculating map of standard deviations values ....')
    image_var1 = scim.convolve(image1 * image1, gauss_kernel) - image_mean_sqr1
    image_var2 = scim.convolve(image2 * image2, gauss_kernel) - image_mean_sqr2
    image_cov = scim.convolve(image1 * image2, gauss_kernel) - image_mean12
    log('Shape of std values map: ', image_var1.shape)

    # Constants C1, C2 stabilize the SSIM ratio where the denominators are small
    L = 0.5 * (np.abs(np.max(image1) - np.min(image1)) + np.abs(np.max(image2) - np.min(image2)))
    C1 = (k1 * L) ** 2
    C2 = (k2 * L) ** 2
    log('\nParameters to calculate the SSIM indeces:')
    log('K1 = ', k1, '   K2 = ', k2, '   L = ', L, '   C1 = ', C1, '   C2 = ', C2)

    # Apply the SSIM formula of the reference paper pixel-wise
    log('\nCalculating map of SSIM values ....')
    map_ssim = ((2 * image_mean12 + C1) * (2 * image_cov + C2)) / \
               ((image_mean_sqr1 + image_mean_sqr2 + C1) * (image_var1 + image_var2 + C2))
    log('.... calculation done!')

    # Mean value of the SSIM map
    MSSIM = np.mean(map_ssim)
    log('\nMSSIM (mean value of the SSIM map) = ', MSSIM)

    return map_ssim, MSSIM


if __name__ == '__main__':
    # Running this file as a script computes nothing; it only reports completion.
    status = 'done'
    print(status)
